from typing import Tuple

from gym.envs.registration import register
import numpy as np
from typing import Tuple, Dict, Text
from highway_env import utils
from highway_env.envs.common.abstract import AbstractEnv
from highway_env.road.lane import LineType, StraightLane, CircularLane, SineLane
from highway_env.road.road import Road, RoadNetwork
from highway_env.vehicle.controller import MDPVehicle


class RoundaboutEnv(AbstractEnv):
    """
    Roundabout environment: a two-lane ring with four single-lane access roads.

    The ego-vehicle starts on the west access road and is routed to the south
    exit, while one other vehicle enters from the south access road and is
    routed to the north exit. The reward is a weighted combination of
    efficiency, safe-driving, comfort and lane-centering terms (see
    ``_rewards`` and ``_reward``).
    """

    @classmethod
    def default_config(cls) -> dict:
        """
        Extend the parent default configuration with roundabout-specific settings.

        The reward weights read by ``_reward`` (``efficiency_weight``,
        ``safe_driving_weight``, ``comfort_weight``, ``lane_centering_reward``)
        and the shaping parameters read by ``_rewards`` (``speed_reward``,
        ``speed_limit``, ``speed_max``, ``target_distance``) are not set in
        this update; their in-code defaults apply unless a user config
        provides them.

        :return: the merged configuration dictionary
        """
        config = super().default_config()
        config.update({
            "observation": {
                "type": "Kinematics",
                # "horizon": 5,
                # "absolute": False,
                # "features": [ "x", "y", "vx", "vy", "lane_heading_difference" ],  #"lane_heading_difference"   #"cos_h", "sin_h",
                # "normalize": False,
                "features_range": {"x": [-100, 100], "y": [-100, 100], "vx": [-25, 25], "vy": [-15, 15]},
            },
            "action": {
                # "type": "DiscreteMetaAction",
                # #"target_speeds": [0, 2, 4]
                "type": "ContinuousAction",
                "speed_range": [2, 5],
                # "target_speeds": [0, 5, 10]
            },
            # Index into ["exr", "sxr", "nxr"] for the (unused) incoming
            # vehicle's destination; None picks one at random.
            "incoming_vehicle_destination": None,
            "collision_reward": -1,
            # NOTE: the next four keys are not read by _reward/_rewards in
            # this class.
            "high_speed_reward": 0.2,
            "right_lane_reward": 0,
            "policy_frequency": 5,
            "lane_change_reward": -0.05,
            "screen_width": 1000,
            "screen_height": 600,
            "centering_position": [0.5, 0.6],
            "duration": 60,
            "normalize_reward": True,
        })
        return config

    def _reward(self, action: np.ndarray) -> float:
        """
        Combine the individual reward terms into one scalar reward.

        Each term from ``_rewards`` is scaled by a weight read from the config
        (with an in-code fallback). The weighted sum is then multiplied by the
        on-road flag, so an off-road ego-vehicle receives 0. The
        ``normalize_reward`` config key is not consulted, and the ``progress``
        term is computed but not used.

        :param action: the action applied at this step
        :return: the scalar reward
        """
        rewards = self._rewards(action)
        # Weighted contribution of each reward term.
        reward = (rewards["efficiency"] * self.config.get("efficiency_weight", 0.25)
                  + rewards["safe_driving"] * self.config.get("safe_driving_weight", 0.25)
                  + rewards["comfort"] * self.config.get("comfort_weight", 0.5)
                  # NOTE: this weight is read from "lane_centering_reward",
                  # not from a "*_weight" key like the others.
                  + rewards["lane_centering_reward"] * self.config.get("lane_centering_reward", 1))
        reward *= rewards["on_road_reward"]
        return reward

    def _rewards(self, action: np.ndarray) -> Dict[Text, float]:
        """
        Compute the individual (unweighted) reward terms for the current step.

        Config keys read here (with defaults): ``collision_reward`` (-1),
        ``speed_reward`` (1), ``speed_limit`` (8), ``speed_max`` (10, must be
        greater than ``speed_limit``) and ``target_distance`` (100).

        :param action: the continuous action applied at this step; the
            magnitude of every component counts against the comfort term
        :return: a dict with keys ``lane_centering_reward``, ``progress``,
            ``on_road_reward``, ``safe_driving``, ``comfort`` and ``efficiency``
        """
        vehicle = self.vehicle
        lane = vehicle.lane
        longitudinal, lateral = lane.local_coordinates(vehicle.position)

        # Lane centering: 1 on the lane centre line, decreasing quadratically
        # with the lateral offset measured relative to the lane width.
        lane_centering_reward = 1 - (lateral / lane.width) ** 2

        # Safe driving: collision penalty plus a term proportional to speed.
        safe_driving_reward = (self.config.get("collision_reward", -1) * vehicle.crashed
                               + self.config.get("speed_reward", 1) * vehicle.speed / 10)

        # Efficiency: rises linearly to 1 at speed_limit, then falls linearly
        # to 0 at speed_max and keeps decreasing above it. The former formula
        # (1 - speed / speed_max) jumped from 1 down to 0.2 just above the
        # limit with the default 8/10 values.
        speed_limit = self.config.get("speed_limit", 8)
        speed_max = self.config.get("speed_max", 10)
        if vehicle.speed <= speed_limit:
            efficiency_reward = vehicle.speed / speed_limit
        else:
            efficiency_reward = 1 - (vehicle.speed - speed_limit) / (speed_max - speed_limit)

        # Comfort: penalise large actuation (e.g. steering and acceleration).
        # For a 2-D action this equals 1 - (|a0|/2 + |a1|/2)/2; the previous
        # code subtracted the two magnitudes, which rewarded one of them.
        comfort_reward = 1 - float(np.mean(np.abs(action))) / 2

        # Progress: remaining distance to the end of the current lane,
        # expressed relative to target_distance.
        distance = lane.length - longitudinal
        target_distance = self.config.get("target_distance", 100)
        progress_reward = (target_distance - distance) / target_distance

        # NOTE(review): intended to normalise the terms to [0, 1], but the
        # source interval [-1, 0] does not match the terms' actual ranges;
        # assuming utils.lmap is an affine interval map, this is a +1 shift.
        # Left unchanged to preserve the reward scale - TODO revisit.
        safe_driving_reward = utils.lmap(safe_driving_reward, [-1, 0], [0, 1])
        comfort_reward = utils.lmap(comfort_reward, [-1, 0], [0, 1])
        lane_centering_reward = utils.lmap(lane_centering_reward, [-1, 0], [0, 1])
        return {"lane_centering_reward": lane_centering_reward,
                "progress": progress_reward,
                "on_road_reward": vehicle.on_road,
                "safe_driving": safe_driving_reward,
                "comfort": comfort_reward,
                "efficiency": efficiency_reward,
                }

    def _is_terminated(self) -> bool:
        """The episode terminates when the ego-vehicle crashes."""
        return self.vehicle.crashed

    def _is_truncated(self) -> bool:
        """The episode is truncated once the configured duration has elapsed."""
        return self.time >= self.config["duration"]

    def _reset(self) -> None:
        """Rebuild the road network, then place the vehicles on it."""
        self._make_road()
        self._make_vehicles()

    def _make_road(self) -> None:
        """
        Build the roundabout road network and store it in ``self.road``.

        The ring consists of two concentric circular lanes split into eight
        segments; each of the four access roads has an entry (straight lane
        followed by a sine-shaped merge) and an exit (sine-shaped diverge
        followed by a straight lane).
        """
        # Circle lanes: (s)outh/(e)ast/(n)orth/(w)est (e)ntry/e(x)it.
        center = [0, 0]  # [m]
        radius = 20  # [m]
        alpha = 24  # [deg]

        net = RoadNetwork()
        radii = [radius, radius + 4]
        n, c, s = LineType.NONE, LineType.CONTINUOUS, LineType.STRIPED
        line = [[c, s], [n, c]]
        for lane in [0, 1]:
            net.add_lane("se", "ex",
                         CircularLane(center, radii[lane], np.deg2rad(90 - alpha), np.deg2rad(alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("ex", "ee",
                         CircularLane(center, radii[lane], np.deg2rad(alpha), np.deg2rad(-alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("ee", "nx",
                         CircularLane(center, radii[lane], np.deg2rad(-alpha), np.deg2rad(-90 + alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("nx", "ne",
                         CircularLane(center, radii[lane], np.deg2rad(-90 + alpha), np.deg2rad(-90 - alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("ne", "wx",
                         CircularLane(center, radii[lane], np.deg2rad(-90 - alpha), np.deg2rad(-180 + alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("wx", "we",
                         CircularLane(center, radii[lane], np.deg2rad(-180 + alpha), np.deg2rad(-180 - alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("we", "sx",
                         CircularLane(center, radii[lane], np.deg2rad(180 - alpha), np.deg2rad(90 + alpha),
                                      clockwise=False, line_types=line[lane]))
            net.add_lane("sx", "se",
                         CircularLane(center, radii[lane], np.deg2rad(90 + alpha), np.deg2rad(90 - alpha),
                                      clockwise=False, line_types=line[lane]))

        # Access lanes: (r)oad/(s)ine
        access = 170  # [m]
        dev = 85  # [m]
        a = 5  # [m]
        delta_st = 0.2 * dev  # [m]

        delta_en = dev - delta_st
        w = 2 * np.pi / dev
        net.add_lane("ser", "ses", StraightLane([2, access], [2, dev / 2], line_types=(s, c)))
        net.add_lane("ses", "se",
                     SineLane([2 + a, dev / 2], [2 + a, dev / 2 - delta_st], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("sx", "sxs",
                     SineLane([-2 - a, -dev / 2 + delta_en], [-2 - a, dev / 2], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("sxs", "sxr", StraightLane([-2, dev / 2], [-2, access], line_types=(n, c)))

        net.add_lane("eer", "ees", StraightLane([access, -2], [dev / 2, -2], line_types=(s, c)))
        net.add_lane("ees", "ee",
                     SineLane([dev / 2, -2 - a], [dev / 2 - delta_st, -2 - a], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("ex", "exs",
                     SineLane([-dev / 2 + delta_en, 2 + a], [dev / 2, 2 + a], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("exs", "exr", StraightLane([dev / 2, 2], [access, 2], line_types=(n, c)))

        net.add_lane("ner", "nes", StraightLane([-2, -access], [-2, -dev / 2], line_types=(s, c)))
        net.add_lane("nes", "ne",
                     SineLane([-2 - a, -dev / 2], [-2 - a, -dev / 2 + delta_st], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("nx", "nxs",
                     SineLane([2 + a, dev / 2 - delta_en], [2 + a, -dev / 2], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("nxs", "nxr", StraightLane([2, -dev / 2], [2, -access], line_types=(n, c)))

        net.add_lane("wer", "wes", StraightLane([-access, 2], [-dev / 2, 2], line_types=(s, c)))
        net.add_lane("wes", "we",
                     SineLane([-dev / 2, 2 + a], [-dev / 2 + delta_st, 2 + a], a, w, -np.pi / 2, line_types=(c, c)))
        net.add_lane("wx", "wxs",
                     SineLane([dev / 2 - delta_en, -2 - a], [-dev / 2, -2 - a], a, w, -np.pi / 2 + w * delta_en,
                              line_types=(c, c)))
        net.add_lane("wxs", "wxr", StraightLane([-dev / 2, -2], [-access, -2], line_types=(n, c)))

        road = Road(network=net, np_random=self.np_random, record_history=self.config["show_trajectories"])
        self.road = road

    def _make_vehicles(self) -> None:
        """
        Place the ego-vehicle and the other traffic on the roundabout.

        The ego-vehicle starts on the west access road and is routed to the
        south exit ("sxr"). One vehicle entering from the south access road is
        added to the road and routed to the north exit ("nxr").

        An "incoming" vehicle on the ring and the vehicles of the
        "other vehicles" loop are still constructed and given routes, but are
        NOT added to the road. They are kept because they draw from
        ``self.np_random``, so removing them would change the random state
        seen by anything sampled afterwards.

        :return: None; the ego-vehicle is stored in ``self.vehicle``.
        """
        position_deviation = 2
        speed_deviation = 2

        # Ego-vehicle
        ego_lane = self.road.network.get_lane(("wer", "wes", 0))
        ego_longitudinal = 85
        ego_vehicle = self.action_type.vehicle_class(self.road,
                                                     ego_lane.position(ego_longitudinal, 0),
                                                     speed=1,
                                                     heading=ego_lane.heading_at(ego_longitudinal))
        try:
            ego_vehicle.plan_route_to("sxr")
        except AttributeError:
            # Best effort: vehicle classes without route planning are allowed.
            pass
        self.road.vehicles.append(ego_vehicle)
        self.vehicle = ego_vehicle

        # Incoming vehicle on the ring (constructed but not added to the road).
        # Random position/speed jitter is disabled here.
        destinations = ["exr", "sxr", "nxr"]
        other_vehicles_type = utils.class_from_path(self.config["other_vehicles_type"])
        vehicle = other_vehicles_type.make_on_lane(self.road,
                                                   ("we", "sx", 1),
                                                   longitudinal=5,
                                                   speed=5)

        if self.config["incoming_vehicle_destination"] is not None:
            destination = destinations[self.config["incoming_vehicle_destination"]]
        else:
            destination = self.np_random.choice(destinations)
        vehicle.plan_route_to(destination)
        vehicle.randomize_behavior()

        # Other vehicles on the ring (constructed but not added to the road).
        # The index list currently evaluates to [-1].
        for i in list(range(1, 1)) + list(range(-1, 0)):
            vehicle = other_vehicles_type.make_on_lane(self.road,
                                                       ("we", "sx", 0),
                                                       longitudinal=20 * i + self.np_random.random() * position_deviation,
                                                       speed=5 + self.np_random.random() * speed_deviation)
            vehicle.plan_route_to(self.np_random.choice(destinations))
            vehicle.randomize_behavior()

        # Entering vehicle from the south access road (added to the road).
        # Random position/speed jitter is disabled here.
        vehicle = other_vehicles_type.make_on_lane(self.road,
                                                   ("ser", "ses", 0),
                                                   longitudinal=50,
                                                   speed=5.72313)
        vehicle.plan_route_to("nxr")
        vehicle.randomize_behavior()
        self.road.vehicles.append(vehicle)


# Make the environment available to gym.make('roundabout-v0').
register(
    entry_point='highway_env.envs:RoundaboutEnv',
    id='roundabout-v0',
)
